#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <time.h>

#include "esmd_types.h"
#include "cell.h"
#include "io_dcd.h"
#include "comm.h"
#include "memory_cpp.hpp"
// Fill in the DCD header, open the output file on the rank with pid == 0 and
// allocate the coordinate buffer used for every frame.
//
// comm_allreduce() is called unconditionally, so presumably every rank has to
// call this function together.
//
//   dcd   state to initialise: header, coordbuf and (pid 0 only) file are set
//   path  output file name; also embedded in the first title line
//   grid  cell grid whose local cells are summed to obtain the atom count
//   mpp   communication context; only pid 0 touches the file
//   dt    value stored in the header's timestep field
//
// If the file cannot be opened, an error is printed to stderr and dcd->file is
// left NULL on pid 0; the header and coordinate buffer are still set up.
void dcd_header_init(dcd_file_t *dcd, const char *path, cellgrid_t *grid, mpp_t *mpp, real dt) {
  memset(&(dcd->header), 0, sizeof(dcd_header_t));

  // Main header record, bracketed by its byte-count markers.
  dcd->header.nbhead_begin = 84;
  // Copy exactly the four magic characters. The header was zeroed above, so a
  // terminator is already in place if the field is wider than four bytes, and
  // nothing is written past the field if it is exactly four bytes.
  memcpy(dcd->header.title, "CORD", 4);
  dcd->header.numframes = 0;
  dcd->header.firststep = 0;
  dcd->header.framestepcnt = 1;
  dcd->header.numsteps = 0;
  dcd->header.timestep = dt;
  dcd->header.iscell = 1;
  dcd->header.charmmversion = 24;
  dcd->header.nbhead_end = 84;

  // Title record: two remark lines.
  dcd->header.nbtitle_begin = 164;
  dcd->header.numtitle = 2;
  snprintf(dcd->header.title_str, sizeof(dcd->header.title_str), "REMARKS FILENAME=\"%s\"", path);
  const time_t abstime = time(NULL);
  const struct tm *tmbuf = localtime(&abstime);
  if (tmbuf != NULL) {
    strftime(dcd->header.create_str, sizeof(dcd->header.create_str), "REMARKS MDIO DCD file created %d %b %Y at %H:%M", tmbuf);
  }
  // else: localtime() failed; the creation line stays zero-filled from the
  // memset above instead of handing a null pointer to strftime().
  dcd->header.nbtitle_end = 164;

  // Atom-count record: sum the local cells, then combine across ranks.
  dcd->header.nbnatoms_begin = 4;
  long natoms = 0;
  FOREACH_LOCAL_CELL(grid, ii, jj, kk, cell){
    natoms += cell->natom;
  }
  comm_allreduce(&natoms, 1, MPI_LONG, MPI_SUM, mpp);
  dcd->header.natoms = natoms;
  dcd->header.nbnatoms_end = 4;

  if (mpp->pid == 0) {
    dcd->file = fopen(path, "wb");
    if (dcd->file == NULL) {
      perror(path);
    } else if (fwrite(&dcd->header, sizeof(dcd_header_t), 1, dcd->file) != 1) {
      fprintf(stderr, "dcd_header_init: failed to write DCD header to \"%s\"\n", path);
    }
  }

  // Buffer layout (in 32-bit words): three blocks of natoms values, each
  // preceded and followed by one marker word holding the block size in bytes.
  // The markers never change, so they are written once here.
  dcd->coordbuf = esmd::allocate<int32_t>(3 * natoms + 6, "dcd/coord buf");
  // Same narrowing the plain assignment performed; the marker word is 32-bit.
  const int32_t block_bytes = static_cast<int32_t>(4 * natoms);
  dcd->coordbuf[0] = block_bytes;
  dcd->coordbuf[natoms + 1] = block_bytes;
  dcd->coordbuf[natoms + 2] = block_bytes;
  dcd->coordbuf[2 * natoms + 3] = block_bytes;
  dcd->coordbuf[2 * natoms + 4] = block_bytes;
  dcd->coordbuf[3 * natoms + 5] = block_bytes;
}

// Collect the positions of all atoms and append one frame (unit-cell record
// followed by the x, y and z coordinate records) to the DCD file.
//
// comm_reduce() is called unconditionally, so presumably every rank has to
// call this function together. Only the rank with pid == 0 writes to the file.
//
//   dcd   DCD state set up by dcd_header_init(); its frame counters are advanced
//   grid  cell grid providing the global box and the local cells
//   mpp   communication context
//
// The frame counters are advanced on every call, exactly as before, even when
// the write is skipped or fails; such failures are reported on stderr.
void dcd_write_frame(dcd_file_t *dcd, cellgrid_t *grid, mpp_t *mpp) {
  // Unit-cell record: box edge lengths in slots 0, 2 and 5, zeros elsewhere.
  dcd->cell.nbcell_begin = 48;
  dcd->cell.unitcell[0] = grid->gbox.hi.x - grid->gbox.lo.x;
  dcd->cell.unitcell[1] = 0;
  dcd->cell.unitcell[2] = grid->gbox.hi.y - grid->gbox.lo.y;
  dcd->cell.unitcell[3] = 0;
  dcd->cell.unitcell[4] = 0;
  dcd->cell.unitcell[5] = grid->gbox.hi.z - grid->gbox.lo.z;
  dcd->cell.nbcell_end = 48;
  dcd->cell.dummy2 = 0;

  // Views of the three coordinate blocks inside coordbuf; each block sits
  // between the marker words written by dcd_header_init().
  float *x = reinterpret_cast<float *>(dcd->coordbuf) + 1;
  float *y = reinterpret_cast<float *>(dcd->coordbuf) + dcd->header.natoms + 3;
  float *z = reinterpret_cast<float *>(dcd->coordbuf) + 2 * dcd->header.natoms + 5;

  // Start from zero so that the sum-reduction below acts as a gather.
  for (int i = 0; i < dcd->header.natoms; i ++) {
    x[i] = 0;
    y[i] = 0;
    z[i] = 0;
  }

  // Scatter the local atoms into their slots, shifted so that the box origin
  // is at zero.
  // NOTE(review): assumes every tag is a 0-based index below header.natoms and
  // that each tag is filled in by exactly one rank - confirm against the
  // tagging code.
  FOREACH_LOCAL_CELL(grid, ii, jj, kk, cell){
    for (int i = 0; i < cell->natom; i ++){
      x[cell->tag[i]] = cell->x[i].x - grid->gbox.lo.x + cell->basis.x;
      y[cell->tag[i]] = cell->x[i].y - grid->gbox.lo.y + cell->basis.y;
      z[cell->tag[i]] = cell->x[i].z - grid->gbox.lo.z + cell->basis.z;
    }
  }

  comm_reduce(x, dcd->header.natoms, MPI_FLOAT, MPI_SUM, mpp);
  comm_reduce(y, dcd->header.natoms, MPI_FLOAT, MPI_SUM, mpp);
  comm_reduce(z, dcd->header.natoms, MPI_FLOAT, MPI_SUM, mpp);

  dcd->header.numframes ++;
  dcd->header.numsteps += dcd->header.framestepcnt;

  if (mpp->pid == 0) {
    if (dcd->file == NULL) {
      fprintf(stderr, "dcd_write_frame: no open DCD file, frame not written\n");
    } else {
      // 56 bytes starting at nbcell_begin: presumably the leading marker, the
      // six unit-cell values and the trailing marker (4 + 48 + 4).
      const size_t cell_ok = fwrite(&dcd->cell.nbcell_begin, 56, 1, dcd->file);
      const size_t nwords = static_cast<size_t>(3 * dcd->header.natoms + 6);
      const size_t words_ok = fwrite(dcd->coordbuf, sizeof(int32_t), nwords, dcd->file);
      if (cell_ok != 1 || words_ok != nwords) {
        fprintf(stderr, "dcd_write_frame: short write, DCD frame is incomplete\n");
      }
    }
  }
}

// Rewrite the header at the start of the file, so that the frame and step
// counters accumulated in dcd->header reach the disk, then close the file.
// Only the rank with pid == 0 does anything; a file that is not open is
// ignored.
//
//   dcd   DCD state whose header is written and whose file is closed
//   mpp   communication context, used only for its pid
void dcd_close(dcd_file_t *dcd, mpp_t *mpp) {
  // dcd->file is only inspected on pid 0, the one rank that opens it.
  if (mpp->pid != 0 || dcd->file == NULL) {
    return;
  }

  if (fseek(dcd->file, 0, SEEK_SET) != 0) {
    // Without a successful rewind the header would land at the current
    // position and corrupt the frame data, so skip the rewrite.
    perror("dcd_close: fseek");
  } else if (fwrite(&dcd->header, sizeof(dcd_header_t), 1, dcd->file) != 1) {
    fprintf(stderr, "dcd_close: failed to rewrite DCD header\n");
  }

  if (fclose(dcd->file) != 0) {
    perror("dcd_close: fclose");
  }
  // Do not leave a dangling FILE pointer behind.
  dcd->file = NULL;

  // NOTE(review): dcd->coordbuf is not released in this function - confirm
  // whether the caller frees it or a matching deallocation belongs here.
}